{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import argparse\n",
    "import os\n",
    "import torch\n",
    "from utils.utils import read_yaml\n",
    "from utils.dataloader_factory import create_dataloader\n",
    "import warnings\n",
    "from eval_utils import evaluation\n",
    "from model import CHIEF_survival\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "\n",
    "def load_model(cfg):\n",
    "    \"\"\"Build a CHIEF_survival model from ``cfg`` and prepare it for inference.\n",
    "\n",
    "    Args:\n",
    "        cfg: Config object; ``cfg.Data.n_classes`` and the ``cfg.Model``\n",
    "            mapping are forwarded to the ``CHIEF_survival`` constructor.\n",
    "\n",
    "    Returns:\n",
    "        The model, moved to CUDA when available (CPU otherwise) and switched\n",
    "        to evaluation mode. No checkpoint weights are loaded here.\n",
    "    \"\"\"\n",
    "    model = CHIEF_survival(n_classes=cfg.Data.n_classes, **cfg.Model)\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    model.to(device)\n",
    "    # This notebook only runs inference, so put the module in eval mode:\n",
    "    # train mode would keep dropout (and similar training-only behaviour)\n",
    "    # active while risks are being predicted.\n",
    "    model.eval()\n",
    "    return model\n",
    "\n",
    "def eval_single(i, model, dataloader, result_dir, cfg, study):\n",
    "    \"\"\"Evaluate ``model`` on ``dataloader`` and add a median-split label.\n",
    "\n",
    "    Delegates to ``evaluation(i, model, dataloader, result_dir, cfg)``, prints\n",
    "    the second value it returns (labelled ``cindex``), then fills a ``pred``\n",
    "    column: 0 where ``risk`` is at or below the median risk, 1 where it is\n",
    "    above. ``study`` is accepted but not used.\n",
    "\n",
    "    Returns:\n",
    "        The DataFrame returned by ``evaluation`` with ``pred`` filled in.\n",
    "    \"\"\"\n",
    "    res_df, cindex = evaluation(i, model, dataloader, result_dir, cfg)\n",
    "    print('cindex:', cindex)\n",
    "\n",
    "    risk = res_df['risk']\n",
    "    cutoff = risk.median()\n",
    "    low_mask = risk <= cutoff\n",
    "    high_mask = risk > cutoff\n",
    "    res_df.loc[low_mask, 'pred'] = 0\n",
    "    res_df.loc[high_mask, 'pred'] = 1\n",
    "    return res_df\n",
    "\n",
    "def eval_pipeline(dataset_name, cfg, result_dir):\n",
    "    \"\"\"Load the checkpoint, evaluate on ``dataset_name`` and save predictions.\n",
    "\n",
    "    NOTE: reads the notebook-level globals ``model`` and ``ckpt_dir``; both\n",
    "    must be assigned before this function is called.\n",
    "\n",
    "    Args:\n",
    "        dataset_name: Dataset identifier passed to ``create_dataloader`` and\n",
    "            ``eval_single``.\n",
    "        cfg: Config object forwarded to the dataloader and evaluation code.\n",
    "        result_dir: Directory passed to the dataloader/evaluation helpers;\n",
    "            created if it does not exist.\n",
    "\n",
    "    Side effects:\n",
    "        Loads the checkpoint weights into the global ``model`` and writes the\n",
    "        result frame to ``./results/evaluation/tcga_rcc/prediction.csv``.\n",
    "    \"\"\"\n",
    "    os.makedirs(result_dir, exist_ok=True)\n",
    "    dataloader = create_dataloader(0, dataset_name, cfg, result_dir)\n",
    "    # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only\n",
    "    # machine; load_state_dict then copies the tensors into the model's own\n",
    "    # parameters, whichever device they live on.\n",
    "    model.load_state_dict(torch.load(ckpt_dir, map_location='cpu'))\n",
    "    res_df = eval_single(0, model, dataloader, result_dir, cfg, dataset_name)\n",
    "    # The CSV path is fixed (a later cell reads it back) and may differ from\n",
    "    # result_dir, so make sure its directory exists before writing.\n",
    "    out_path = './results/evaluation/tcga_rcc/prediction.csv'\n",
    "    os.makedirs(os.path.dirname(out_path), exist_ok=True)\n",
    "    res_df.to_csv(out_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cindex: 0.9\n"
     ]
    }
   ],
   "source": [
    "# Evaluation run. eval_pipeline reads `model` and `ckpt_dir` as notebook\n",
    "# globals, so both must be assigned before the call at the end of this cell.\n",
    "cfg = read_yaml('./cfg_rcc_cox_dss.yaml')\n",
    "dataset_name = 'val_set'\n",
    "ckpt_dir = './weights/best_checkpoints/s_0_checkpoint.pt'\n",
    "\n",
    "model = load_model(cfg)\n",
    "result_dir1 = os.path.join(cfg.General.result_dir, 'evaluation', cfg.Data.project)\n",
    "\n",
    "eval_pipeline(dataset_name, cfg, result_dir1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "     patient_id  status      risk  pred\n",
      "0  TCGA-B0-5691     0.0 -1.265317   0.0\n",
      "1  TCGA-B0-5713     0.0 -0.398486   0.0\n",
      "2  TCGA-BP-4167     0.0 -0.328713   0.0\n",
      "3  TCGA-BP-4329     1.0  0.913136   1.0\n",
      "4  TCGA-BP-4335     1.0  0.405751   1.0\n",
      "5  TCGA-CJ-4885     0.0  0.262248   1.0\n",
      "6  TCGA-KL-8339     1.0 -0.144757   1.0\n",
      "7  TCGA-KN-8429     0.0 -1.714606   0.0\n",
      "8  TCGA-P4-A5E6     0.0 -0.862301   0.0\n",
      "9  TCGA-P4-AAVL     1.0  2.203031   1.0\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Read back the CSV written by eval_pipeline and print the summary columns.\n",
    "df = pd.read_csv('./results/evaluation/tcga_rcc/prediction.csv')\n",
    "print(df.loc[:, ['patient_id', 'status', 'risk', 'pred']])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
  },
  "kernelspec": {
   "display_name": "Python 3.8.10 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
